/*
 * Copyright The OpenTelemetry Authors
 * SPDX-License-Identifier: Apache-2.0
 */

package io.opentelemetry.javaagent.instrumentation.httpurlconnection;

import static io.opentelemetry.javaagent.extension.matcher.AgentElementMatchers.extendsClass;
import static io.opentelemetry.javaagent.instrumentation.httpurlconnection.HttpUrlConnectionSingletons.HTTP_URL_STATE;
import static io.opentelemetry.javaagent.instrumentation.httpurlconnection.HttpUrlConnectionSingletons.instrumenter;
import static net.bytebuddy.matcher.ElementMatchers.isProtected;
import static net.bytebuddy.matcher.ElementMatchers.isPublic;
import static net.bytebuddy.matcher.ElementMatchers.nameStartsWith;
import static net.bytebuddy.matcher.ElementMatchers.named;
import static net.bytebuddy.matcher.ElementMatchers.namedOneOf;
import static net.bytebuddy.matcher.ElementMatchers.not;

import io.opentelemetry.context.Context;
import io.opentelemetry.context.Scope;
import io.opentelemetry.javaagent.bootstrap.CallDepth;
import io.opentelemetry.javaagent.extension.instrumentation.TypeInstrumentation;
import io.opentelemetry.javaagent.extension.instrumentation.TypeTransformer;
import java.net.HttpURLConnection;
import javax.annotation.Nullable;
import net.bytebuddy.asm.Advice;
import net.bytebuddy.description.type.TypeDescription;
import net.bytebuddy.matcher.ElementMatcher;

/**
 * Instruments {@link HttpURLConnection} implementations.
 *
 * <p>{@code connect()}, {@code getOutputStream()}, {@code getInputStream()} and the protected
 * {@code plainConnect()} are wrapped by {@link HttpUrlConnectionAdvice}. That advice starts an
 * instrumenter operation on the first top-level call and ends it when the call either throws or
 * {@code getInputStream()} returns with a populated response code. {@code getResponseCode()} is
 * wrapped by {@link GetResponseCodeAdvice}, which remembers the returned status code so that it
 * can be reported if the operation later ends with an exception.
 */
public class HttpUrlConnectionInstrumentation implements TypeInstrumentation {
  /**
   * Matches {@code java.net.*}, {@code sun.net*} and the WebLogic connection class when they
   * extend {@code java.net.HttpURLConnection}, excluding the delegating {@code
   * HttpsURLConnectionImpl}.
   */
  @Override
  public ElementMatcher<TypeDescription> typeMatcher() {
    return nameStartsWith("java.net.")
        .or(nameStartsWith("sun.net"))
        // In WebLogic, URL.openConnection() returns its own internal implementation of
        // HttpURLConnection, which does not delegate the methods that have to be instrumented to
        // the JDK superclass. Therefore, it needs to be instrumented directly.
        .or(named("weblogic.net.http.HttpURLConnection"))
        // This class is a simple delegator. Skip because it does not update its `connected`
        // field.
        .and(not(named("sun.net.www.protocol.https.HttpsURLConnectionImpl")))
        .and(extendsClass(named("java.net.HttpURLConnection")));
  }

  /** Registers the connection-lifecycle advice and the response-code capturing advice. */
  @Override
  public void transform(TypeTransformer transformer) {
    transformer.applyAdviceToMethod(
        isPublic()
            .and(namedOneOf("connect", "getOutputStream", "getInputStream"))
            // ibm https url connection does not delegate connect, it calls plainConnect instead
            .or(isProtected().and(named("plainConnect"))),
        this.getClass().getName() + "$HttpUrlConnectionAdvice");
    transformer.applyAdviceToMethod(
        isPublic().and(named("getResponseCode")),
        this.getClass().getName() + "$GetResponseCodeAdvice");
  }

  /**
   * Advice for {@code connect}, {@code getOutputStream}, {@code getInputStream} and {@code
   * plainConnect}. It starts or resumes an operation on entry and possibly ends it on exit.
   */
  @SuppressWarnings("unused")
  public static class HttpUrlConnectionAdvice {

    /**
     * Per-invocation state carried from method enter to method exit.
     *
     * <p>{@code httpUrlState} and {@code scope} are {@code null} when the call is nested, when
     * {@code shouldStart} rejected it, or (for {@code scope} only) when the operation for this
     * connection has already finished.
     */
    public static class AdviceScope {
      private final CallDepth callDepth;
      private final HttpUrlState httpUrlState;
      private final Scope scope;

      private AdviceScope(CallDepth callDepth, HttpUrlState httpUrlState, Scope scope) {
        this.callDepth = callDepth;
        this.httpUrlState = httpUrlState;
        this.scope = scope;
      }

      /**
       * Increments the call depth and, for top-level calls only, starts a new operation or
       * resumes the one already stored on {@code connection}.
       *
       * @param callDepth call-depth tracker shared by all advised {@link HttpURLConnection}
       *     methods
       * @param connection the instrumented connection
       * @return a scope to pass to {@link #end}; never {@code null}
       */
      public static AdviceScope start(CallDepth callDepth, HttpURLConnection connection) {
        if (callDepth.getAndIncrement() > 0) {
          // only want the rest of the instrumentation rules (which are complex enough) to apply to
          // top-level HttpURLConnection calls
          return new AdviceScope(callDepth, null, null);
        }

        Context parentContext = Context.current();
        if (!instrumenter().shouldStart(parentContext, connection)) {
          return new AdviceScope(callDepth, null, null);
        }

        // using virtual field for a couple of reasons:
        // - to start an operation in connect() and end it in getInputStream()
        // - to avoid creating a new operation on multiple subsequent calls to getInputStream()
        HttpUrlState httpUrlState = HTTP_URL_STATE.get(connection);

        if (httpUrlState != null) {
          if (!httpUrlState.finished) {
            return new AdviceScope(callDepth, httpUrlState, httpUrlState.context.makeCurrent());
          }
          return new AdviceScope(callDepth, httpUrlState, null);
        }

        Context context = instrumenter().start(parentContext, connection);
        httpUrlState = new HttpUrlState(context);
        HTTP_URL_STATE.set(connection, httpUrlState);
        return new AdviceScope(callDepth, httpUrlState, context.makeCurrent());
      }

      /**
       * Decrements the call depth and, for a top-level call that made a context current, closes
       * that scope and possibly ends the operation.
       *
       * <p>The operation is ended when the advised method threw, or when {@code getInputStream}
       * returned with a positive {@code responseCode} field value.
       *
       * @param connection the instrumented connection
       * @param responseCode value of the connection's {@code responseCode} field at method exit
       * @param throwable exception thrown by the advised method, or {@code null}
       * @param methodName name of the advised method
       */
      public void end(
          HttpURLConnection connection,
          int responseCode,
          @Nullable Throwable throwable,
          String methodName) {
        if (callDepth.decrementAndGet() > 0 || scope == null) {
          return;
        }

        // prevent infinite recursion in case end() captures response headers due to
        // HttpUrlConnection.getHeaderField() calling HttpUrlConnection.getInputStream() which then
        // enters this advice again
        callDepth.getAndIncrement();
        try {
          scope.close();
          Class<? extends HttpURLConnection> connectionClass = connection.getClass();

          String requestMethod = connection.getRequestMethod();
          GetOutputStreamContext.set(
              httpUrlState.context, connectionClass, methodName, requestMethod);

          if (throwable != null) {
            if (responseCode >= 400) {
              // HttpURLConnection unnecessarily throws exception on error response.
              // None of the other http clients do this, so not recording the exception on the span
              // to be consistent with the telemetry for other http clients.
              instrumenter().end(httpUrlState.context, connection, responseCode, null);
            } else {
              // Fall back to the status code captured by GetResponseCodeAdvice when the
              // responseCode field has not been populated.
              instrumenter()
                  .end(
                      httpUrlState.context,
                      connection,
                      responseCode > 0 ? responseCode : httpUrlState.statusCode,
                      throwable);
            }
            httpUrlState.finished = true;
          } else if (methodName.equals("getInputStream") && responseCode > 0) {
            // responseCode field is sometimes not populated.
            // We can't call getResponseCode() due to some unwanted side-effects
            // (e.g. breaks getOutputStream).
            instrumenter().end(httpUrlState.context, connection, responseCode, null);
            httpUrlState.finished = true;
          }
        } finally {
          callDepth.decrementAndGet();
        }
      }
    }

    @Advice.OnMethodEnter(suppress = Throwable.class)
    public static AdviceScope methodEnter(@Advice.This HttpURLConnection connection) {
      CallDepth callDepth = CallDepth.forClass(HttpURLConnection.class);
      return AdviceScope.start(callDepth, connection);
    }

    @Advice.OnMethodExit(onThrowable = Throwable.class, suppress = Throwable.class)
    public static void methodExit(
        @Advice.This HttpURLConnection connection,
        @Advice.FieldValue("responseCode") int responseCode,
        @Advice.Thrown @Nullable Throwable throwable,
        @Advice.Origin("#m") String methodName,
        @Advice.Enter AdviceScope adviceScope) {
      adviceScope.end(connection, responseCode, throwable, methodName);
    }
  }

  /**
   * Advice for {@code getResponseCode()}. It stores the returned status code on the connection's
   * {@link HttpUrlState}, if one exists. {@link HttpUrlConnectionAdvice} uses that value when the
   * {@code responseCode} field is not populated.
   */
  @SuppressWarnings("unused")
  public static class GetResponseCodeAdvice {

    // suppress = Throwable.class: advice is inlined into the application's getResponseCode(), so
    // any failure here must be swallowed rather than surfacing to the caller (consistent with the
    // other advice methods in this class).
    @Advice.OnMethodExit(suppress = Throwable.class)
    public static void methodExit(
        @Advice.This HttpURLConnection connection, @Advice.Return int returnValue) {

      HttpUrlState httpUrlState = HTTP_URL_STATE.get(connection);
      if (httpUrlState != null) {
        httpUrlState.statusCode = returnValue;
      }
    }
  }
}
